{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 第 10 章 自然语言处理入门"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10.1 分词"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10.1.1 英文分词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['the', '5', 'biggest', 'countries', 'by', 'population', 'in', '2017', 'are', 'china', 'india', 'united', 'states', 'indonesia', 'and', 'brazil']\n"
     ]
    }
   ],
   "source": [
    "import tensorflow.keras.preprocessing.text as kp_text\n",
    "\n",
    "# Adjacent string literals inside the parentheses are joined into one sentence.\n",
    "paragraph = (\n",
    "    \"The 5 biggest countries by population in 2017 are China, \"\n",
    "    \"India, United States, Indonesia, and Brazil.\"\n",
    ")\n",
    "\n",
    "# Split the paragraph into a list of word tokens with the Keras text helper.\n",
    "processed_text = kp_text.text_to_word_sequence(paragraph)\n",
    "print(processed_text)\n",
    "\n",
    "# Expected output:\n",
    "# ['the', '5', 'biggest', 'countries', 'by', 'population', 'in', '2017',\n",
    "#  'are', 'china', 'india', 'united', 'states', 'indonesia', 'and', 'brazil']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 10.1.2 中文分词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /var/folders/k_/86_kl60s3vg28cg5x8_b0yb80000gn/T/jieba.cache\n",
      "Loading model cost 0.790 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "我/ 来到/ 北京/ 清华/ 清华大学/ 华大/ 大学\n",
      "我/ 来到/ 北京/ 清华大学\n",
      "他, 来到, 了, 网易, 杭研, 大厦\n",
      "小明, 硕士, 毕业, 于, 中国, 科学, 学院, 科学院, 中国科学院, 计算, 计算所\n"
     ]
    }
   ],
   "source": [
    "import jieba\n",
    "\n",
    "# The same sentence is segmented twice to compare the two cut modes.\n",
    "sentence = \"我来到北京清华大学\"\n",
    "\n",
    "# Full mode (cut_all=True)\n",
    "seg_list = jieba.cut(sentence, cut_all=True)\n",
    "print(\"/ \".join(seg_list))\n",
    "\n",
    "# Accurate mode (cut_all=False)\n",
    "seg_list = jieba.cut(sentence, cut_all=False)\n",
    "print(\"/ \".join(seg_list))\n",
    "\n",
    "# cut_all is omitted here; accurate mode is the default\n",
    "seg_list = jieba.cut(\"他来到了网易杭研大厦\")\n",
    "print(\", \".join(seg_list))\n",
    "\n",
    "# Search-engine mode\n",
    "seg_list = jieba.cut_for_search(\"小明硕士毕业于中国科学院计算所\")\n",
    "print(\", \".join(seg_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 10.2 语言模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 10.2.1 独热编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%% \n"
    }
   },
   "outputs": [],
   "source": [
    "token2idx = {'人工智能': 0, '的': 1, '研究': 2, '可以': 3, '分为': 4, '几个': 5,\n",
    "             '技术': 6, '问题': 7, '是': 8, '一门': 9, '新': 10, '科学': 11}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'人工智能': 1, '的': 2, '技术': 3, '研究': 4, '可以': 5, '分为': 6, '几个': 7, '问题': 8, '是': 9, '一门': 10, '新': 11, '科学': 12}\n",
      "[[1, 2, 4, 5, 6, 7, 3, 8], [1, 9, 10, 11, 2, 3, 12]]\n",
      "[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]]\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.keras.preprocessing.text import Tokenizer\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "# Two pre-segmented sentences; tokens are separated by spaces.\n",
    "samples = [\n",
    "    \"人工智能 的 研究 可以 分为 几个 技术 问题\",\n",
    "    \"人工智能 是 一门 新 的 技术 科学\",\n",
    "]\n",
    "\n",
    "# Build the token -> index vocabulary from the samples.\n",
    "tokenizer = Tokenizer()\n",
    "tokenizer.fit_on_texts(samples)\n",
    "print(tokenizer.word_index)\n",
    "\n",
    "# Replace every token by its vocabulary index.\n",
    "sequence = tokenizer.texts_to_sequences(samples)\n",
    "print(sequence)\n",
    "\n",
    "# The printed word_index starts at 1, so one extra class is needed to\n",
    "# cover index 0 when one-hot encoding the first sentence.\n",
    "num_classes = len(tokenizer.word_index) + 1\n",
    "print(to_categorical(sequence[0], num_classes=num_classes))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 10.2.2 词嵌入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.102539  0.245417  0.530431  0.180411  0.386218  0.82535   0.530675\n",
      " -0.244311 -0.681341  0.431361 -0.576909 -0.231161  0.401618 -0.22292\n",
      " -0.063599  0.169365  0.282321  0.535496  0.825031 -0.049239  0.327087\n",
      "  0.460255 -0.103537  0.183484  0.578723 -0.01837   0.362876  0.538694\n",
      "  0.035194  0.725767 -0.414246 -0.667796 -0.122474  0.794132 -0.168633\n",
      "  0.284308 -0.475918 -0.161586 -1.174823  0.725582  0.541358  0.473443\n",
      " -0.643121  0.03602  -0.3345   -0.939873 -0.752846 -0.4204    0.300105\n",
      " -0.141766 -0.224159  0.538478 -0.040462 -0.832016  0.049291  0.123678\n",
      " -0.030749  0.070804  0.451031  0.027303 -0.247646  0.021237  0.53187\n",
      " -0.007717 -0.115273  0.268562  0.472504 -0.316475  0.554178  0.109384\n",
      " -0.149233 -0.179639  0.640718 -0.260974  0.125321  0.089919 -0.64585\n",
      "  0.375543  0.253541  0.142418  0.568716 -0.197845 -0.256447  0.389201\n",
      "  0.037378 -0.301636  0.639302 -0.686622 -0.146073 -0.041729 -0.591598\n",
      "  0.347904  0.313897 -0.508647 -0.027256 -0.121522  0.335613 -0.196134\n",
      " -0.465964  0.097934  0.197466  0.108443  0.146774 -0.187733  0.471964\n",
      " -0.132781 -0.150485 -0.411566 -0.077239  0.469825 -0.429396 -0.338848\n",
      " -0.266035 -0.114337 -0.241434 -0.484068  0.421364  0.164792 -0.795331\n",
      "  0.877335  0.282501 -0.049753 -0.248542  0.068796 -0.721903 -0.201949\n",
      "  0.095261 -0.11024   0.697713  0.406093  0.328016  0.208961 -0.180162\n",
      "  0.364711  0.113749 -0.538457 -0.365502  0.439329  0.148537  0.151965\n",
      "  0.243821  0.279266  0.33319   0.110524 -0.139981  0.049252 -0.262786\n",
      "  0.532628  0.002224  0.128616  0.416762 -0.101012 -0.152559 -0.160026\n",
      " -0.826332 -0.201008 -0.07649  -0.473925 -0.25654   0.177313 -0.692837\n",
      "  0.421114 -0.203732 -0.440396 -0.651813  0.592739 -0.310016  0.489182\n",
      " -0.472919  0.077262  0.594825 -0.174081  0.545577 -0.028856 -0.02753\n",
      " -0.008633  0.618105  0.403909 -0.295179  0.325014  0.141281 -0.140456\n",
      "  0.272524  0.051696  0.270376 -0.029866 -0.481393 -0.492252 -0.571477\n",
      " -0.550171 -0.413476 -0.33778  -1.049183 -0.029875 -0.518371  0.626629\n",
      " -0.205087  0.204765  0.105452 -0.335756 -0.208589  0.316677 -0.002744\n",
      "  0.875735 -0.479659  0.296657  0.150437  0.146031 -0.258952  0.140439\n",
      "  0.374487  0.258912  0.041128  0.147469  0.467683 -0.127148  0.66636\n",
      " -0.454136 -0.311296 -0.123492  0.705563 -0.072737  0.068641  0.431099\n",
      " -0.699881 -0.736209  0.086087  0.063452 -0.105587 -0.195417 -0.715053\n",
      "  0.159465  0.837999 -0.158468 -0.560022 -0.165479  0.347496 -0.312444\n",
      " -0.711437  0.18064  -0.551316  0.650957  0.837707 -0.140695  0.051283\n",
      "  0.172842  0.490827 -0.692713 -0.051433  0.343387  0.037457  0.50919\n",
      " -1.046952  0.173763  0.79278   0.229149  0.3506   -0.01893  -0.417689\n",
      "  0.485504 -0.316264 -0.639605 -0.698171 -0.989755  0.007368  0.555524\n",
      " -0.239776 -0.08648   0.34435   0.806608 -0.022164 -0.755346  0.528093\n",
      "  0.468402 -0.348716  1.036324  0.981486  0.311801 -0.175254 -0.385483\n",
      " -0.086164  0.510664 -0.349232  0.72494   0.055372  0.2303   -0.065205\n",
      " -0.167833 -0.321994 -0.746782 -0.473484 -0.369037 -0.02137   0.364299\n",
      "  0.504032 -1.147905 -0.334222 -0.107973 -0.210832  0.251977]\n"
     ]
    }
   ],
   "source": [
    "import gensim\n",
    "\n",
    "# Location of the pre-trained word vectors, read with load_word2vec_format.\n",
    "model_path = 'data/word2vec/sgns.weibo.bigram-char'\n",
    "w2v = gensim.models.KeyedVectors.load_word2vec_format(model_path)\n",
    "\n",
    "# Look up the embedding of a single word and show it.\n",
    "query_word = '猫咪'\n",
    "vector = w2v[query_word]\n",
    "print(vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "word list: ['，', '的', '。', '@', '！', '了', '一个', '：', '我们', '、', '有', '】', '一', '是', '和', '？', '微博', '不', '“', '�']\n",
      "word vectors: [[ 0.325048  0.897771 -0.418062 ... -0.47089  -0.11149   0.315512]\n",
      " [ 0.440554 -0.417241  0.1474   ... -0.235042 -0.124166  0.566537]\n",
      " [ 0.516277 -0.002949 -0.456919 ...  0.332332  0.050709  0.471254]\n",
      " ...\n",
      " [ 0.53585   0.344069  0.572108 ...  0.02912   0.009632  0.111846]\n",
      " [ 0.399981  0.189998  0.033076 ...  0.448605 -0.498731  1.144464]\n",
      " [ 0.144157  0.631021  0.475352 ...  0.715005 -0.632832 -0.617412]]\n",
      "most similiar to 猫咪: \n",
      " [('猫咪', 1.0000001192092896), ('狗狗', 0.6089414358139038), ('猫', 0.561676025390625), ('猫猫', 0.5580824613571167), ('喵', 0.4925574064254761), ('喵星人', 0.48239561915397644), ('主人', 0.45755279064178467), ('狗', 0.4399756193161011), ('宝宝', 0.4378788471221924), ('兔子', 0.4181772470474243)]\n",
      "most similiar to 明星: \n",
      " [('名人', 0.43296903371810913), ('笑星', 0.4187435507774353), ('好莱坞', 0.39082780480384827), ('女神', 0.383768767118454), ('大牌', 0.38327860832214355), ('偶像', 0.36212867498397827), ('模特', 0.3587229251861572), ('嘉宾', 0.3569623827934265), ('精英', 0.3542693257331848), ('人物', 0.34731006622314453)]\n"
     ]
    }
   ],
   "source": [
    "# First 20 words of the embedding vocabulary\n",
    "first_words = w2v.index2word[:20]\n",
    "print(f\"word list: {first_words}\")\n",
    "\n",
    "# Vectors of the first 20 words, shape (20, 300)\n",
    "first_vectors = w2v.vectors[:20]\n",
    "print(f\"word vectors: {first_vectors}\")\n",
    "\n",
    "# Neighbours of a raw vector; the recorded output lists the query word itself first\n",
    "vector = w2v['猫咪']\n",
    "neighbours_of_vector = w2v.similar_by_vector(vector)\n",
    "print(f\"most similiar to 猫咪: \\n {neighbours_of_vector}\")\n",
    "\n",
    "# Neighbours looked up directly by word\n",
    "neighbours_of_word = w2v.most_similar('明星')\n",
    "print(f\"most similiar to 明星: \\n {neighbours_of_word}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.layers import Embedding\n",
    "\n",
    "# input_dim is the vocabulary size: token indices 0..999 are accepted\n",
    "# (999 tokens remain if index 0 is kept for padding).\n",
    "VOCAB_SIZE = 1000\n",
    "# Length of each embedding vector.\n",
    "EMBEDDING_DIM = 128\n",
    "\n",
    "embedding_layer = Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 10.2.3 从文本到词嵌入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True]\n",
      "[ True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True  True  True  True  True  True  True  True  True]\n"
     ]
    }
   ],
   "source": [
    "import gensim\n",
    "import numpy as np\n",
    "from typing import List\n",
    "from tensorflow import keras\n",
    "\n",
    "model_path = 'data/word2vec/sgns.weibo.bigram-char'\n",
    "# Pre-trained embeddings are usually large, so loading them costs time and memory.\n",
    "# When memory is limited or a quick experiment is enough, pass `limit` to read only\n",
    "# the first N vectors of the file. Only 5000 word vectors are loaded here.\n",
    "w2v = gensim.models.KeyedVectors.load_word2vec_format(model_path, limit=5000)\n",
    "\n",
    "token2index = {\n",
    "    '<PAD>': 0,  # sequences are padded with 0, so the padding token must have index 0\n",
    "    '<UNK>': 1   # the unknown-token index could be any value; 1 is simply convenient\n",
    "}\n",
    "\n",
    "# Append every word of the pre-trained vocabulary to the token index.\n",
    "for token in w2v.index2word:\n",
    "    token2index[token] = len(token2index)\n",
    "\n",
    "# All-zero matrix of shape [number of tokens, pre-trained vector size];\n",
    "# row 0 (<PAD>) stays zero.\n",
    "token_vector = np.zeros((len(token2index), w2v.vector_size))\n",
    "# Random vector for <UNK>, sized from the loaded model rather than a hard-coded 300\n",
    "# so that embeddings of another dimension work as well.\n",
    "token_vector[1] = np.random.rand(w2v.vector_size)\n",
    "# Rows from index 2 onwards hold the pre-trained vectors.\n",
    "token_vector[2:] = w2v.vectors\n",
    "\n",
    "# Sanity check: the rebuilt index and matrix must map each word to its original vector.\n",
    "print(token_vector[token2index['成长']] == w2v['成长'])\n",
    "print(token_vector[token2index['市场']] == w2v['市场'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_2 (Embedding)      (None, None, 300)         1500600   \n",
      "=================================================================\n",
      "Total params: 1,500,600\n",
      "Trainable params: 0\n",
      "Non-trainable params: 1,500,600\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Initialise the embedding layer from the pre-trained vectors prepared in token_vector\n",
    "L = keras.layers\n",
    "embedding_layer = L.Embedding(\n",
    "    input_dim=len(token2index),   # one row per token in the vocabulary\n",
    "    output_dim=w2v.vector_size,   # same dimension as the pre-trained vectors\n",
    "    weights=[token_vector],       # start from the weight matrix built from w2v\n",
    "    trainable=False,              # keep the pre-trained weights frozen\n",
    ")\n",
    "\n",
    "# A model whose only job is to turn index sequences into vector sequences\n",
    "model = keras.Sequential()\n",
    "model.add(embedding_layer)\n",
    "\n",
    "# The model is never trained, so the optimizer and loss are arbitrary choices\n",
    "model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[89, 438, 462, 242, 1]\n"
     ]
    }
   ],
   "source": [
    "def convert_token_2_idx(tokenized_sentence: List[str]) -> List[int]:\n",
    "    \"\"\"Map a tokenized sentence to its sequence of token indices.\n",
    "\n",
    "    A token found in ``token2index`` is replaced by its index; any other\n",
    "    token is replaced by the index of the ``<UNK>`` token.\n",
    "\n",
    "    Args:\n",
    "        tokenized_sentence: the tokens of one sentence, in order\n",
    "    Returns:\n",
    "        the token indices, in the same order as the input tokens\n",
    "    \"\"\"\n",
    "    return [token2index.get(token, token2index['<UNK>'])\n",
    "            for token in tokenized_sentence]\n",
    "\n",
    "tokenized_sentence = \"今天 天气 真 不错 ha\".split(' ')\n",
    "print(convert_token_2_idx(tokenized_sentence))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 5, 300)\n"
     ]
    }
   ],
   "source": [
    "sentence_index = convert_token_2_idx(tokenized_sentence)\n",
    "# Add a leading batch axis so the single sentence becomes a batch of one sample\n",
    "input_x = np.array(sentence_index, ndmin=2)\n",
    "# Run the index batch through the embedding model to obtain the vectors\n",
    "sentence_vector = model.predict(input_x)\n",
    "print(sentence_vector.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
